## TODO:

- Normalize targets after clamping — helps get training MSE down to ~0 when training on the first two samples.

Notebook by Paul Scotti, with code adapted from Aidan Dempster (https://github.com/Veldrovive/open_clip).
In particular, please somebody try out the various networks Aidan shared (https://github.com/Veldrovive/open_clip/blob/main/src/open_clip/model.py) which includes more complex architectures like transformers and architectures that handle both 2D and 3D voxels.
I also have a DistributedDataParallel version of this notebook for anyone who might want to use this with multi-gpu on Slurm (just ask me for it).
!nvidia-smi
Wed Nov 16 16:59:13 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01 Driver Version: 515.65.01 CUDA Version: 11.7 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA RTX A6000 Off | 00000000:01:00.0 Off | Off |
| 30% 52C P5 85W / 300W | 2MiB / 49140MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
Import packages & functions¶
xxxxxxxxxx# You will need to download files from huggingface and change the respective paths to those files# https://huggingface.co/datasets/pscotti/naturalscenesdataset/tree/mainxxxxxxxxxx#!pip install "git+https://github.com/openai/CLIP.git@main#egg=clip"#!pip install git+https://github.com/openai/CLIP.gitxxxxxxxxxx#!pip install info-nce-pytorchxxxxxxxxxximport osimport sysimport mathimport numpy as npimport pandas as pdfrom matplotlib import pyplot as pltimport seaborn as snssns.set(font_scale=1.0)import torchfrom torch import nnimport torchvisionfrom torchvision import transformsfrom tqdm import tqdmimport PILfrom datetime import datetimeimport h5pyimport webdataset as wdsfrom info_nce import InfoNCEimport clipimport timefrom collections import OrderedDictfrom glob import globfrom PIL import Imagexxxxxxxxxximport osimport sysimport mathimport numpy as npfrom matplotlib import pyplot as pltimport torchfrom torch import nnimport torchvisionfrom torchvision import transformsfrom tqdm import tqdmimport PILfrom datetime import datetimeimport h5pyimport webdataset as wdsfrom info_nce import InfoNCEimport clipdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(device)mean=np.array([0.48145466, 0.4578275, 0.40821073])std=np.array([0.26862954, 0.26130258, 0.27577711])denorm = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())def np_to_Image(x): return PIL.Image.fromarray((x.transpose(1, 2, 0)*127.5+128).clip(0,255).astype('uint8'))def torch_to_Image(x,device=device): x = denorm(x) return transforms.ToPILImage()(x)def Image_to_torch(x): return (transforms.ToTensor()(x[0])[:3].unsqueeze(0)-.5)/.5def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8): #https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements denominator = torch.max(torch.sqrt(torch.outer(A_l2, B_l2)), torch.tensor(eps)) return torch.div(numerator, denominator)def batchwise_cosine_similarity(Z,B): # 
https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc B = B.T Z_norm = torch.linalg.norm(Z, dim=1, keepdim=True) # Size (n, 1). B_norm = torch.linalg.norm(B, dim=0, keepdim=True) # Size (1, b). cosine_similarity = ((Z @ B) / (Z_norm @ B_norm)).T return cosine_similaritydef get_non_diagonals(a): a = torch.triu(a,diagonal=1)+torch.tril(a,diagonal=-1) # make diagonals -1 a=a.fill_diagonal_(-1) return adef topk(similarities,labels,k=5): if k > similarities.shape[0]: k = similarities.shape[0] topsum=0 for i in range(k): topsum += torch.sum(torch.argsort(similarities,axis=1)[:,-(i+1)] == labels)/len(labels) return topsumxxxxxxxxxxdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(device)mean = np.array([0.48145466, 0.4578275, 0.40821073])std = np.array([0.26862954, 0.26130258, 0.27577711])denorm = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())def np_to_Image(x): return PIL.Image.fromarray((x.transpose(1, 2, 0)*127.5+128).clip(0,255).astype('uint8'))def torch_to_Image(x,device=device): x = denorm(x) return transforms.ToPILImage()(x)def Image_to_torch(x): return (transforms.ToTensor()(x[0])[:3].unsqueeze(0)-.5)/.5def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8): #https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements denominator = torch.max(torch.sqrt(torch.outer(A_l2, B_l2)), torch.tensor(eps)) return torch.div(numerator, denominator)def batchwise_cosine_similarity(Z, B): # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc B = B.T Z_norm = torch.linalg.norm(Z, dim=1, keepdim=True) # Size (n, 1). B_norm = torch.linalg.norm(B, dim=0, keepdim=True) # Size (1, b). 
cosine_similarity = ((Z @ B) / (Z_norm @ B_norm)).T return cosine_similaritydef get_non_diagonals(a): a = torch.triu(a,diagonal=1)+torch.tril(a,diagonal=-1) # make diagonals -1 a=a.fill_diagonal_(-1) return adef topk(similarities,labels,k=5): if k > similarities.shape[0]: k = similarities.shape[0] topsum=0 for i in range(k): topsum += torch.sum(torch.argsort(similarities, axis=1)[:,-(i+1)] == labels)/len(labels) return topsumdef get_preprocs(): preproc_vox = transforms.Compose([transforms.ToTensor(), torch.nan_to_num]) preproc_img = transforms.Compose([ transforms.Resize(size=(224,224)), transforms.Normalize(mean=mean, std=std), ]) return preproc_vox, preproc_imgWhich pretrained model are you using for voxel alignment to embedding space?¶
# NOTE(review): garbled single-line notebook export. This span holds the
# model-selection cell plus two versions (original + edited) of the cell that
# freezes the pretrained CLIP model and computes image embeddings. The
# enclosing `def` line (presumably something like `def embedder(image): ...`)
# appears to have been lost in the export — the orphaned `return image_features`
# implies it; confirm against the original notebook. The edited version adds
# L2 normalization after clamping to [-1.5, 1.5] (the "Normalize targets after
# clamping" TODO at the top of the file).
xxxxxxxxxxmodel_name = 'clip_image_vit' # CLIP ViT-L/14 image embeddings# model_name = 'clip_text_vit' # CLIP ViT-L/14 text embeddings# model_name = 'clip_image_resnet' # CLIP basic ResNet image embeddingsprint(f"Using model: {model_name}")xxxxxxxxxx# dont want to train modelmodel.eval()# dont need to calculate gradientsfor param in model.parameters(): param.requires_grad = Falseif model_name=='clip_text_vit': f = h5py.File('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_subj_indices.hdf5', 'r') subj01_order = f['subj01'][:] image_features = model.encode_image(image.to(device)) if "vit" in model_name: # I think this is the clamping used by Lin Sprague Singh preprint image_features = torch.clamp(image_features,-1.5,1.5) return image_features #print(model)xxxxxxxxxx# dont want to train modelmodel.eval()# dont need to calculate gradientsfor param in model.parameters(): param.requires_grad = Falseif model_name == 'clip_text_vit': f = h5py.File('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_subj_indices.hdf5', 'r') subj01_order = f['subj01'][:] image_features = model.encode_image(image.to(device)) if "vit" in model_name: # I think this is the clamping used by Lin Sprague Singh preprint image_features = torch.clamp(image_features, -1.5, 1.5) # normalize after clipping per the paper image_features = nn.functional.normalize(image_features, dim=-1) return image_features #print(model)Load data¶
NSD webdatasets for subjects 1, 2, and 3 are publicly available here:
https://huggingface.co/datasets/pscotti/naturalscenesdataset/tree/main/webdataset
# NOTE(review): garbled export of the data-loading cells. Sets `full_training`,
# picks shard URL templates for the two on-disk layouts (Princeton /scratch vs.
# local copy — they differ in tar naming and voxel key), then builds `train_dl`
# and `val_dl` webdataset pipelines yielding (voxels, emb_name) tuples.
# Several <TODO> markers disable shuffling and `.with_epoch(...)`; the trailing
# "WARNING ... num_worker_batches 83 / 2" text is captured cell output, not code.
xxxxxxxxxx# use large batches and the complete training dataset? full_training = Trueprint('full_training',full_training)xxxxxxxxxx# NAT_SCENE = "/scratch/gpfs/KNORMAN/webdataset_nsd/webdataset_split/"NAT_SCENE = "/home/jimgoo/data/neuro/naturalscenesdataset/webdataset/"# the tar files have a slightly different formatif "/scratch/gpfs/KNORMAN" in NAT_SCENE: SUBJ_FORMAT = "train_subj01_{{{}..{}}}.tar" SUBJ_FORMAT_VAL = "val_subj01_0.tar" VOXELS_KEY = 'nsdgeneral.npy'else: SUBJ_FORMAT = "subj01_nsdgeneral_{{{}..{}}}.tar" SUBJ_FORMAT_VAL = "val_subj01_nsdgeneral_0.tar" VOXELS_KEY = 'voxel.npy'xxxxxxxxxxSUBJ_FORMAT.format(0, 1)'subj01_nsdgeneral_{0..1}.tar'xxxxxxxxxx## things in one sample of data:# sample00000.voxel.npy# sample00000.voxel_3d.npy# sample00000.trial.npy# sample00000.sgxl_emb.npy# sample00000.jpgpreproc_vox, preproc_img = get_preprocs()# <TODO> check augmentation results before forward pass# image augmentation just for the CLIP image model that will be more semantic-focused# img_augment = transforms.Compose([# transforms.RandomCrop(size=(140,140)),# transforms.Resize(size=(224,224)),# transforms.RandomHorizontalFlip(p=.5),# transforms.ColorJitter(.4,.4,.2,.1),# transforms.RandomGrayscale(p=.2),# ])# <TODO> try more thingsimg_augment = transforms.Compose([ transforms.Resize(size=(224,224)), ])if not full_training: num_devices = 1 num_workers = 4 print("num_workers", num_workers) batch_size = 16 print("batch_size", batch_size) num_samples = 500 global_batch_size = batch_size * num_devices print("global_batch_size", global_batch_size) num_batches = math.floor(num_samples / global_batch_size) num_worker_batches = math.floor(num_batches / num_workers) print("num_worker_batches", num_worker_batches) train_url = os.path.join(NAT_SCENE, "train", SUBJ_FORMAT.format(0, 1)) else: # num_devices = torch.cuda.device_count() num_devices = 1 print("WARNING: num_devices hardcoded") print("num_devices", num_devices) # num_workers = num_devices * 4 num_workers = 1 # <TODO> 
switch back the above print("WARNING num_workers hardcoded") print("num_workers", num_workers) batch_size = 300 # batch_size = 1 # print("WARNING tiny batch size") print("batch_size",batch_size) num_samples = 24983 # see metadata.json in webdataset_split folder global_batch_size = batch_size * num_devices print("global_batch_size", global_batch_size) num_batches = math.floor(num_samples / global_batch_size) num_worker_batches = math.floor(num_batches / num_workers) print("num_worker_batches", num_worker_batches) train_url = os.path.join(NAT_SCENE, "train", SUBJ_FORMAT.format(0, 49))train_data = wds.DataPipeline([ # wds.ResampledShards(train_url), # <TODO> switch back to this once I understand it wds.SimpleShardList(train_url), wds.tarfile_to_samples(), # wds.shuffle(500, initial=500), # <TODO> this seems hardcoded for `full_training=False` wds.decode("torch"), #wds.rename(images="jpg;png", voxels=VOXELS_KEY, embs="sgxl_emb.npy", trial="trial.npy"), wds.rename(images="jpg;png", voxels=VOXELS_KEY), # <TODO> use less-lean version above wds.map_dict(images=preproc_img), wds.to_tuple("voxels", emb_name), wds.batched(batch_size, partial=True), ]) #.with_epoch(num_worker_batches) # <TODO> add this backtrain_dl = wds.WebLoader(train_data, num_workers=num_workers, batch_size=None, shuffle=False, persistent_workers=True)# Validation #num_samples = 492num_batches = math.ceil(num_samples / global_batch_size)num_worker_batches = math.ceil(num_batches / num_workers)print("validation: num_worker_batches", num_worker_batches)url = os.path.join(NAT_SCENE, "val", SUBJ_FORMAT_VAL)val_data = wds.DataPipeline([ # wds.ResampledShards(url), # <TODO> switch back to this once I understand it wds.SimpleShardList(url), wds.tarfile_to_samples(), wds.decode("torch"), # wds.rename(images="jpg;png", voxels=VOXELS_KEY, embs="sgxl_emb.npy", trial="trial.npy"), wds.rename(images="jpg;png", voxels=VOXELS_KEY), # <TODO> use less-lean version above wds.map_dict(images=preproc_img), 
wds.to_tuple("voxels", emb_name), wds.batched(batch_size, partial=True), ])#.with_epoch(num_worker_batches) # <TODO> add this backval_dl = wds.WebLoader(val_data, num_workers=num_workers, batch_size=None, shuffle=False, persistent_workers=True)WARNING: num_devices hardcoded num_devices 1 WARNING num_workers hardcoded num_workers 1 batch_size 300 global_batch_size 300 num_worker_batches 83 validation: num_worker_batches 2
def test_loader(dl):
    """Run through one batch of `dl` to sanity-check shapes; return embedding dim.

    Relies on notebook globals defined earlier: `emb_name`, `device`, `embedder`,
    and (for the text path) `text_tokenize` / `subj01_annots` — TODO confirm the
    text branch against the original notebook; this span was garbled in export.
    """
    # BUGFIX: `out_dim` was only bound inside the loop, so an empty loader
    # produced a confusing NameError; fail with a clear message instead.
    out_dim = None
    for i, (voxel, emb) in enumerate(dl):
        print("idx", i)
        print("voxel.shape", voxel.shape)
        print("emb.shape", emb.shape)
        if emb_name == 'images':
            # image embedding: raw images are fed to the embedder below
            emb = emb.to(device)
        else:
            # text embedding
            text_tokens = text_tokenize(subj01_annots[emb]).to(device)
        emb = embedder(emb)
        print("emb.shape2", emb.shape)
        out_dim = emb.shape[1]
        print("out_dim", out_dim)
        break
    if out_dim is None:
        raise ValueError("test_loader: data loader yielded no batches")
    return out_dim

out_dim = test_loader(train_dl)
out_dim = test_loader(val_dl)
# NOTE(review): commented-out loader timing benchmarks, then grabs the first
# train/val batch into (voxel0, emb0) / (val_voxel0, val_emb0) — reused later
# for quick overfitting experiments. The trailing shape tuples are cell output.
xxxxxxxxxx# t0 = time.time()# n_batches = 0# for train_i, (voxel0, emb0) in enumerate(train_dl):# n_batches += 1# t1 = time.time()# # 84, 233.06136536598206# n_batches, t1-t0xxxxxxxxxx# t0 = time.time()# n_batches = 0# for val_i, (val_voxel0, val_emb0) in enumerate(val_dl):# n_batches += 1# t1 = time.time()# # (492, 3.9010021686553955)# n_batches, t1-t0xxxxxxxxxx# get the first batch of everythingfor train_i, (voxel0, emb0) in enumerate(train_dl): breakfor val_i, (val_voxel0, val_emb0) in enumerate(val_dl): breakxxxxxxxxxxvoxel0.shape, val_voxel0.shape(torch.Size([300, 15724]), torch.Size([300, 15724]))
xxxxxxxxxxemb0.shape, val_emb0.shape(torch.Size([300, 3, 224, 224]), torch.Size([300, 3, 224, 224]))
# NOTE(review): image display + voxel-range plot, followed by a superseded
# duplicate of the data-loading cell above that uses the hard-coded
# /scratch/gpfs/KNORMAN paths, ResampledShards, shuffling, random image
# augmentation, and `.with_epoch(num_worker_batches)`. Kept verbatim for
# reference; the active pipeline is the one built earlier from NAT_SCENE.
# Trailing "num_devices 1 ... out_dim 768" text is captured cell output.
xxxxxxxxxxtorch_to_Image(emb0[0])xxxxxxxxxxtorch_to_Image(val_emb0[0])xxxxxxxxxx# <TODO> scale the voxels once I understand more about the format of trials inside the tar dataset filesV = voxel0.cpu().numpy()plt.plot(np.vstack((np.max(V, 0), np.mean(V, 0), np.min(V, 0))).T);plt.legend(['max', 'mean', 'min']);plt.xlabel('position in flattened voxel array');plt.ylabel('voxel value');xxxxxxxxxxpreproc_vox = transforms.Compose([transforms.ToTensor(),torch.nan_to_num])preproc_img = transforms.Compose([ transforms.Resize(size=(224,224)), transforms.Normalize(mean=mean, std=std), ])# image augmentation just for the CLIP image model that will be more semantic-focusedimg_augment = transforms.Compose([ transforms.RandomCrop(size=(140,140)), transforms.Resize(size=(224,224)), transforms.RandomHorizontalFlip(p=.5), transforms.ColorJitter(.4,.4,.2,.1), transforms.RandomGrayscale(p=.2), ])if not full_training: num_devices = 1 num_workers = 4 print("num_workers",num_workers) batch_size = 16 print("batch_size",batch_size) num_samples = 500 global_batch_size = batch_size * num_devices print("global_batch_size",global_batch_size) num_batches = math.floor(num_samples / global_batch_size) num_worker_batches = math.floor(num_batches / num_workers) print("num_worker_batches",num_worker_batches) train_url = "/scratch/gpfs/KNORMAN/webdataset_nsd/webdataset_split/train/train_subj01_{0..1}.tar"else: num_devices = torch.cuda.device_count() print("num_devices",num_devices) num_workers = num_devices * 4 print("num_workers",num_workers) batch_size = 300 print("batch_size",batch_size) num_samples = 24983 # see metadata.json in webdataset_split folder global_batch_size = batch_size * num_devices print("global_batch_size",global_batch_size) num_batches = math.floor(num_samples / global_batch_size) num_worker_batches = math.floor(num_batches / num_workers) print("num_worker_batches",num_worker_batches) train_url = 
"/scratch/gpfs/KNORMAN/webdataset_nsd/webdataset_split/train/train_subj01_{0..49}.tar"train_data = wds.DataPipeline([wds.ResampledShards(train_url), wds.tarfile_to_samples(), wds.shuffle(500,initial=500), wds.decode("torch"), wds.rename(images="jpg;png", voxels="nsdgeneral.npy", embs="sgxl_emb.npy", trial="trial.npy"), wds.map_dict(images=preproc_img), wds.to_tuple("voxels", emb_name), wds.batched(batch_size, partial=True), ]).with_epoch(num_worker_batches)train_dl = wds.WebLoader(train_data, num_workers=num_workers, batch_size=None, shuffle=False, persistent_workers=True)# Validation #num_samples = 492num_batches = math.ceil(num_samples / global_batch_size)num_worker_batches = math.ceil(num_batches / num_workers)print("validation: num_worker_batches",num_worker_batches)url = "/scratch/gpfs/KNORMAN/webdataset_nsd/webdataset_split/val/val_subj01_0.tar"val_data = wds.DataPipeline([wds.ResampledShards(url), wds.tarfile_to_samples(), wds.decode("torch"), wds.rename(images="jpg;png", voxels="nsdgeneral.npy", embs="sgxl_emb.npy", trial="trial.npy"), wds.map_dict(images=preproc_img), wds.to_tuple("voxels", emb_name), wds.batched(batch_size, partial=True), ]).with_epoch(num_worker_batches)val_dl = wds.WebLoader(val_data, num_workers=num_workers, batch_size=None, shuffle=False, persistent_workers=True)# check that your data loaders are workingfor train_i, (voxel, emb) in enumerate(train_dl): print("idx",train_i) print("voxel.shape",voxel.shape) if emb_name=='images': # image embedding emb = emb.to(device) else: # text embedding text_tokens = text_tokenize(subj01_annots[emb]).to(device) print("emb.shape",emb.shape) emb = embedder(emb) print("emb.shape",emb.shape) out_dim = emb.shape[1] print("out_dim", out_dim) breaknum_devices 1 num_workers 4 batch_size 300 global_batch_size 300 num_worker_batches 20 validation: num_worker_batches 1 idx 0 voxel.shape torch.Size([300, 15724]) emb.shape torch.Size([300, 3, 224, 224]) emb.shape torch.Size([300, 768]) out_dim 768
Initialize network¶
# Two candidate architectures; the second definition (plain MLP) intentionally
# shadows the first and is the one actually instantiated below.
class BrainNetwork(nn.Module):
    """Conv1d front-end plus weight-shared residual linear blocks (earlier attempt)."""

    def __init__(self, out_dim, h=7861):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=3, stride=1, padding=0),
            nn.Dropout1d(0.1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.lin = nn.Linear(h, h)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(251552, out_dim)

    def forward(self, x):
        # [B, V] -> [B, 1, V] -> conv/pool -> [B, 32, h]
        x = x[:, None, :]
        x = self.conv(x)
        residual = x
        for _ in range(4):
            # the same linear layer is reused for every residual block
            x = self.lin(x)
            x += residual
            x = self.relu(x)
            residual = x
        # flatten [B, 32, h] -> [B, 32*h] and project to the embedding dim
        x = x.reshape(len(x), -1)
        x = self.lin1(x)
        return x

# PS note: i also tried the below network and it didn't work nearly as good at the top one


class BrainNetwork(nn.Module):  # noqa: F811 -- deliberately replaces the conv version above
    """Three-hidden-layer MLP mapping flattened voxels to the embedding space."""

    def __init__(self, out_dim, input_size=15724, h1=4096, h2=2048, h3=1024, pdrop=0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            # torch.nn.BatchNorm1d(input_size),
            nn.Linear(input_size, h1),
            nn.ReLU(),
            nn.Dropout(pdrop),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Dropout(pdrop),
            nn.Linear(h2, h3),
            nn.ReLU(),
            nn.Dropout(pdrop),
            nn.Linear(h3, out_dim),
        )

    def forward(self, x):
        return self.mlp(x)


def param_count(model):
    """number of params in model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# reset rng seed so weight init is reproducible
torch.manual_seed(123)
np.random.seed(123)

# init model (`out_dim` comes from test_loader above)
brain_net = BrainNetwork(out_dim)
print("{:,} params".format(param_count(brain_net)))
brain_net
BrainNetwork(
(mlp): Sequential(
(0): Linear(in_features=15724, out_features=4096, bias=True)
(1): ReLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=2048, bias=True)
(4): ReLU()
(5): Dropout(p=0.1, inplace=False)
(6): Linear(in_features=2048, out_features=1024, bias=True)
(7): ReLU()
(8): Dropout(p=0.1, inplace=False)
(9): Linear(in_features=1024, out_features=768, bias=True)
)
)xxxxxxxxxx# reset rng seedtorch.manual_seed(123)np.random.seed(123)# init modelbrain_net = BrainNetwork(out_dim)brain_net = brain_net.to(device)# test out that the neural network can run without error:with torch.cuda.amp.autocast(): out = brain_net(voxel.to(device)) print(out.shape)xxxxxxxxxxbrain_net = brain_net.to(device)# test out that the neural network can run without error:with torch.cuda.amp.autocast(): out = brain_net(voxel0.to(device)) print(out.shape)Train model¶
# NOTE(review): garbled export of the training-setup cells, each present in two
# versions. Old: AdamW @ 1e-6 with ReduceLROnPlateau and InfoNCE bound to `nce`.
# New: Adam @ 3e-4, scheduler disabled, InfoNCE bound to `loss_fun`. Also
# defines plotting helpers (plot_training / plot_preds / plot_err), a running
# AverageMeter, and sets up `outdir` via IPython `!` shell magics (the `tree`
# listing at the end is captured cell output, not code).
xxxxxxxxxxif full_training: num_epochs = 100else: num_epochs = 20initial_learning_rate = 1e-6optimizer = torch.optim.AdamW(brain_net.parameters(), lr=initial_learning_rate)# optimizer = torch.optim.SGD(brain_net.parameters(), lr=initial_learning_rate, momentum=0.95)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=1e-8, patience=5) nce = InfoNCE() # what we will use for loss function # Other losses to consider: #xxxxxxxxxxif full_training: num_epochs = 100else: num_epochs = 20 #initial_learning_rate = 1e-6initial_learning_rate = 3e-4# initial_learning_rate = 0.01#initial_learning_rate = 3e-3#print("WARNING - large learning rate", initial_learning_rate)optimizer = torch.optim.Adam(brain_net.parameters(), lr=initial_learning_rate)# optimizer = torch.optim.SGD(brain_net.parameters(), lr=initial_learning_rate, momentum=0.95)#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=1e-8, patience=5)# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)loss_fun = InfoNCE() # what we will use for loss function # loss_fun = nn.MSELoss()#loss_fun = nn.MSELoss()# Other losses to consider: #xxxxxxxxxxdef plot_training(): print(f"num_epochs:{num_epochs} batch_size:{batch_size} lr:{initial_learning_rate}") fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(17,3)) ax1.set_title(f"Training Loss\n(final={train_losses[-1]})") ax1.plot(train_losses) ax2.set_title(f"Training Performance\n(final={train_percent_correct[-1]})") ax2.plot(train_percent_correct) ax3.set_title(f"Val Loss\n(final={val_losses[-1]})") ax3.plot(val_losses) ax4.set_title(f"Val Performance\n(final={val_percent_correct[-1]})") ax4.plot(val_percent_correct) plt.show() def plot_preds(y, y_hat, title='', outdir=''): true = y.cpu().detach().numpy().T pred = y_hat.cpu().detach().numpy().T for i in range(y.shape[0]): plt.plot(np.vstack((true[:,i], pred[:,i])).T); plt.legend(['true', 'pred']); plt.title(title + ' sample %i' % i) if outdir: if not 
os.path.exists(outdir): os.makedirs(outdir) plt.savefig(outdir + '/%s-preds-sample-%i.jpeg' % (title, i)) plt.close() else: plt.show() def plot_err(y, y_hat, title=''): err = (y - y_hat) err = err.cpu().detach().numpy() plt.plot(err.T) plt.title(title) plt.show(); class AverageMeter: def __init__(self, name=None): self.name = name self.reset() def reset(self): self.sum = self.count = self.avg = 0 def update(self, val, n=1): self.sum += val * n self.count += n self.avg = self.sum / self.count xxxxxxxxxxoutdir = './checkpoints/v01'xxxxxxxxxx!rm -rf $outdir/*xxxxxxxxxx!mkdir -p $outdir!mkdir -p $outdir/preds/train/!mkdir -p $outdir/preds/val/xxxxxxxxxx!tree $outdir/./checkpoints/v01/ └── preds ├── train └── val 3 directories, 0 files
# NOTE(review): two versions of the training loop, garbled onto long lines.
# Old version: uses `nce` as the loss, `scheduler.step(val_loss)`, and saves a
# checkpoint every 5 epochs. Revised version: uses `loss_fun`, accumulates
# per-epoch averages with AverageMeter, saves example predictions under
# `outdir/preds/`, writes `epoch-logs.csv`, and has checkpointing / LR decay
# commented out. In BOTH train loops the line computing `emb_` (presumably
# `emb_ = brain_net(voxel)`, mirroring `val_emb_ = brain_net(val_voxel)` in the
# validation loops) appears to have been lost in the export — `emb_` is used
# before any visible assignment; confirm against the original notebook.
xxxxxxxxxxprint(f"num_epochs:{num_epochs} batch_size:{batch_size} lr:{initial_learning_rate}")print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))print(f"Will be saving model checkpoints to checkpoints/{model_name}_subj01_epoch#.pth")epoch = 0train_losses = []; val_losses = []train_percent_correct = []val_percent_correct = []lrs = []# # resuming from checkpoint?# lrs=checkpoint['lrs']pbar = tqdm(range(epoch,num_epochs))for epoch in pbar: brain_net.train() similarities = [] for train_i, (voxel, emb) in enumerate(train_dl): optimizer.zero_grad() voxel = voxel.to(device) with torch.cuda.amp.autocast(): if emb_name=='images': # image embedding if torch.any(torch.isnan(emb_)): raise ValueError("NaN found...") emb_ = nn.functional.normalize(emb_,dim=-1) # l2 normalization on the embeddings labels = torch.arange(len(emb)).to(device) loss = nce(emb_.reshape(len(emb),-1),emb.reshape(len(emb),-1)) similarities = batchwise_cosine_similarity(emb,emb_) percent_correct = topk(similarities,labels,k=1) loss.backward() optimizer.step() train_losses.append(loss.item()) train_percent_correct.append(percent_correct.item()) brain_net.eval() # using all validation samples to compute loss for val_i, (val_voxel, val_emb) in enumerate(val_dl): with torch.no_grad(): val_voxel = val_voxel.to(device) with torch.cuda.amp.autocast(): if emb_name=='images': # image embedding val_emb_ = brain_net(val_voxel) labels = torch.arange(len(val_emb)).to(device) val_loss = nce(val_emb_.reshape(len(val_emb),-1),val_emb.reshape(len(val_emb),-1)) val_similarities = batchwise_cosine_similarity(val_emb,val_emb_) percent_correct = topk(val_similarities,labels,k=1) val_losses.append(val_loss.item()) val_percent_correct.append(percent_correct.item()) if epoch%5==0 and full_training: torch.save({ 'epoch': epoch, 'model_state_dict': brain_net.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'train_losses': train_losses, 'val_losses': val_losses, 'train_percent_correct': train_percent_correct, 
'val_percent_correct': val_percent_correct, 'lrs': lrs, }, f'checkpoints/{model_name}_subj01_epoch{epoch}.pth') scheduler.step(val_loss) lrs.append(optimizer.param_groups[0]['lr']) pbar.set_description(f"Loss: {np.median(train_losses[-(train_i+1):]):.3f} | VLoss: {np.median(val_losses[-(val_i+1):]):.3f} | TopK%: {np.median(train_percent_correct[-10:]):.3f} | VTopK%: {np.median(val_percent_correct[-10:]):.3f} | lr{lrs[-1]:.5f}") print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))xxxxxxxxxxbs = 300print(f"num_epochs:{num_epochs} batch_size:{batch_size} lr:{initial_learning_rate}")print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))print(f"Will be saving model checkpoints to checkpoints/{model_name}_subj01_epoch#.pth")if not os.path.exists("checkpoints"): os.makedirs("checkpoints")epoch = 0train_losses = []; val_losses = []train_percent_correct = []val_percent_correct = []lrs = []epoch_logs = []# # resuming from checkpoint?# lrs=checkpoint['lrs']pbar = tqdm(range(epoch, num_epochs))for epoch in pbar: brain_net.train() train_loss_avg = AverageMeter() train_topk_avg = AverageMeter() val_loss_avg = AverageMeter() val_topk_avg = AverageMeter() for train_i, (voxel, emb) in enumerate(train_dl): #for train_i, (voxel, emb) in enumerate([(voxel0[:bs], emb0[:bs])]): # voxel = voxel0 # emb = emb0 bsz = voxel.shape[0] voxel = voxel.to(device) with torch.cuda.amp.autocast(): if emb_name=='images': # image embedding if torch.any(torch.isnan(emb_)): raise ValueError("NaN found...") emb_ = nn.functional.normalize(emb_, dim=-1) # l2 normalization on the embeddings labels = torch.arange(bsz).to(device) loss = loss_fun(emb_.reshape(bsz, -1), emb.reshape(bsz, -1)) similarities = batchwise_cosine_similarity(emb, emb_) percent_correct = topk(similarities, labels, k=1) optimizer.zero_grad() loss.backward() optimizer.step() train_losses.append(loss.item()) train_percent_correct.append(percent_correct.item()) train_loss_avg.update(loss.detach_(), bsz) 
train_topk_avg.update(percent_correct.detach_(), bsz) if train_i == 0 and epoch % 5 == 0: # plot_preds(emb[:2], emb_[:2], 'train', outdir + '/preds/epoch-%03d' % epoch) torch.save((emb_[:2], emb[:2]), outdir + '/preds/train/epoch-%03d.to' % epoch) # if train_i >= 0: # break brain_net.eval() # using all validation samples to compute loss for val_i, (val_voxel, val_emb) in enumerate(val_dl): #for val_i, (val_voxel, val_emb) in enumerate([(val_voxel0[:bs], val_emb0[:bs])]): # val_voxel = val_voxel0 # val_emb = val_emb0 bsz = val_voxel.shape[0] with torch.no_grad(): val_voxel = val_voxel.to(device) with torch.cuda.amp.autocast(): if emb_name=='images': # image embedding val_emb_ = brain_net(val_voxel) val_emb_ = nn.functional.normalize(val_emb_, dim=-1) # l2 normalization on the embeddings labels = torch.arange(bsz).to(device) val_loss = loss_fun(val_emb_.reshape(bsz,-1), val_emb.reshape(bsz,-1)) val_similarities = batchwise_cosine_similarity(val_emb, val_emb_) percent_correct = topk(val_similarities, labels, k=1) val_losses.append(val_loss.item()) val_percent_correct.append(percent_correct.item()) val_loss_avg.update(val_loss.detach_(), bsz) val_topk_avg.update(percent_correct.detach_(), bsz) if val_i == 0 and epoch % 5 == 0: # plot_preds(val_emb[:2], val_emb_[:2], 'val', outdir + '/preds/epoch-%03d' % epoch) torch.save((val_emb_[:2], val_emb[:2]), outdir + '/preds/val/epoch-%03d.to' % epoch) # if val_i >= 0: # break# if epoch % 5 == 0 and full_training:# torch.save({# 'epoch': epoch,# 'model_state_dict': brain_net.state_dict(),# 'optimizer_state_dict': optimizer.state_dict(),# 'train_losses': train_losses,# 'val_losses': val_losses,# 'train_percent_correct': train_percent_correct,# 'val_percent_correct': val_percent_correct,# 'lrs': lrs,# }, f'checkpoints/{model_name}_subj01_epoch{epoch}.pth') # <TODO> add back LR decay # scheduler.step(val_loss) lrs.append(optimizer.param_groups[0]['lr']) #pbar.set_description(f"Loss: {np.median(train_losses[-(train_i+1):]):.3f} | 
VLoss: {np.median(val_losses[-(val_i+1):]):.3f} | TopK%: {np.median(train_percent_correct[-10:]):.3f} | VTopK%: {np.median(val_percent_correct[-10:]):.3f} | lr{lrs[-1]:.5f}") logs = OrderedDict( loss=train_loss_avg.avg.item(), topk=train_topk_avg.avg.item(), val_loss=val_loss_avg.avg.item(), val_topk=val_topk_avg.avg.item(), lr=lrs[-1], ) epoch_logs.append(logs) pbar.set_postfix(**logs) pd.DataFrame(epoch_logs).to_csv(outdir + '/epoch-logs.csv') print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))xxxxxxxxxxplot_training()num_epochs:100 batch_size:300 lr:0.0003
xxxxxxxxxxpd.DataFrame(epoch_logs).plot(subplots=True);xxxxxxxxxxtrain_topk_avg.sum, train_topk_avg.count, train_topk_avg.avg(tensor(21086., device='cuda:0'), 24983, tensor(0.8440, device='cuda:0'))
xxxxxxxxxxtrain_percent_correct[-1], len(train_percent_correct), sum(train_percent_correct)(0.9036144018173218, 8400, 6814.349324496463)
xxxxxxxxxx## save modelxxxxxxxxxx#!rm -rf checkpoints/*.pthxxxxxxxxxxmodel_name'clip_image_vit'
xxxxxxxxxxepoch99
# NOTE(review): final checkpoint save (model/optimizer state plus loss and
# accuracy histories) followed by two versions of the results-plotting cell —
# an inline one and its refactoring into `plot_saved(ckpt_path)`. The trailing
# "Plotting results from ..." text is captured cell output, not code.
xxxxxxxxxxckpt_path = f'checkpoints/{model_name}_subj01_epoch{epoch}.pth'xxxxxxxxxxtorch.save({ 'epoch': epoch, 'model_state_dict': brain_net.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'train_losses': train_losses, 'val_losses': val_losses, 'train_percent_correct': train_percent_correct, 'val_percent_correct': val_percent_correct, 'lrs': lrs, }, ckpt_path)xxxxxxxxxxprint(f"num_epochs:{num_epochs} batch_size:{batch_size} lr:{initial_learning_rate}")fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(17,3))ax1.set_title(f"Training Loss\n(final={train_losses[-1]})")ax1.plot(train_losses)ax2.set_title(f"Training Performance\n(final={train_percent_correct[-1]})")ax2.plot(train_percent_correct)ax3.set_title(f"Val Loss\n(final={val_losses[-1]})")ax3.plot(val_losses)ax4.set_title(f"Val Performance\n(final={val_percent_correct[-1]})")ax4.plot(val_percent_correct)plt.show()Plot losses from saved model¶
xxxxxxxxxx# Loading ckpt_path = 'checkpoints/clip_image_vit_subj01_epoch20.pth' checkpoint = torch.load(ckpt_path, map_location=device)print(f"Plotting results from {ckpt_path}")train_losses=checkpoint['train_losses']train_percent_correct=checkpoint['train_percent_correct']val_losses=checkpoint['val_losses']val_percent_correct=checkpoint['val_percent_correct']fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(17,3))ax1.set_title(f"Training Loss\n(final={train_losses[-1]})")ax1.plot(train_losses)ax2.set_title(f"Training Performance\n(final={train_percent_correct[-1]})")ax2.plot(train_percent_correct)ax3.set_title(f"Val Loss\n(final={val_losses[-1]})")ax3.plot(val_losses)ax4.set_title(f"Val Performance\n(final={val_percent_correct[-1]})")ax4.plot(val_percent_correct)plt.show()xxxxxxxxxxdef plot_saved(ckpt_path): # Loading # ckpt_path = 'checkpoints/clip_image_vit_subj01_epoch20.pth' checkpoint = torch.load(ckpt_path, map_location=device) print(f"Plotting results from {ckpt_path}") train_losses=checkpoint['train_losses'] train_percent_correct=checkpoint['train_percent_correct'] val_losses=checkpoint['val_losses'] val_percent_correct=checkpoint['val_percent_correct'] fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(17,3)) ax1.set_title(f"Training Loss\n(final={train_losses[-1]})") ax1.plot(train_losses) ax2.set_title(f"Training Performance\n(final={train_percent_correct[-1]})") ax2.plot(train_percent_correct) ax3.set_title(f"Val Loss\n(final={val_losses[-1]})") ax3.plot(val_losses) ax4.set_title(f"Val Performance\n(final={val_percent_correct[-1]})") ax4.plot(val_percent_correct) plt.show()xxxxxxxxxxplot_saved(ckpt_path)Plotting results from checkpoints/clip_image_vit_subj01_epoch99.pth
Evaluating Top-K Image Retrieval¶
Restart the kernel, re-run the "Import packages & functions" and "Initialize network" cells, and then run the cells below.
# ---------------------------------------------------------------------------
# Validation dataloader: WebDataset pipeline over the NSD validation shards.
# (Last run reported num_samples, batch_size, num_workers, num_worker_batches
#  = 492, 300, 1, 2.)
# ---------------------------------------------------------------------------
# num_samples = 492
# batch_size = 300
# num_batches = 1
# num_workers = 1
# num_worker_batches = 1
preproc_vox, preproc_img = get_preprocs()

url = os.path.join(NAT_SCENE, "val", SUBJ_FORMAT_VAL)
val_data = wds.DataPipeline([
    # wds.ResampledShards(url),  # <TODO> switch back to this once I understand it
    wds.SimpleShardList(url),
    wds.tarfile_to_samples(),
    wds.decode("torch"),
    wds.rename(images="jpg;png", voxels=VOXELS_KEY,
               embs="sgxl_emb.npy", trial="trial.npy"),
    wds.map_dict(images=preproc_img),
    wds.to_tuple("voxels", emb_name, "trial"),
    wds.batched(batch_size, partial=True),
]).with_epoch(1)  # num_worker_batches
val_dl = wds.WebLoader(val_data, num_workers=num_workers,
                       batch_size=None, shuffle=False, persistent_workers=True)

# --- Earlier hard-coded version of the same dataloader (left in place) ------
num_samples = 492
batch_size = 300
num_batches = 1
num_workers = 1
num_worker_batches = 1
preproc_vox = transforms.Compose([transforms.ToTensor(), torch.nan_to_num])
preproc_img = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.Normalize(mean=mean, std=std),
])
url = "/scratch/gpfs/KNORMAN/webdataset_nsd/webdataset_split/val/val_subj01_0.tar"
val_data = wds.DataPipeline([
    wds.ResampledShards(url),
    wds.tarfile_to_samples(),
    wds.decode("torch"),
    wds.rename(images="jpg;png", voxels="nsdgeneral.npy",
               embs="sgxl_emb.npy", trial="trial.npy"),
    wds.map_dict(images=preproc_img),
    wds.to_tuple("voxels", "images", "trial"),
    wds.batched(batch_size, partial=True),
]).with_epoch(num_worker_batches)
val_dl = wds.WebLoader(val_data, num_workers=num_workers,
                       batch_size=None, shuffle=False, persistent_workers=True)

# ---------------------------------------------------------------------------
# Embedders used as retrieval targets (earlier three-model version).
# ---------------------------------------------------------------------------
clip_model, _ = clip.load("ViT-L/14", device=device)
resnet_model, _ = clip.load("RN50", device=device)
clip_model.eval()
resnet_model.eval()

# Per-subject ordering into the 73k COCO images.
with h5py.File('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_subj_indices.hdf5', 'r') as f:
    subj01_order = f['subj01'][:]

# COCO annotations curated the same way as the mind_reader (Lin Sprague Singh) preprint.
annots = np.load('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_annots_curated.npy',
                 allow_pickle=True)
subj01_annots = annots[subj01_order]

def text_tokenize(annots):
    """Randomly pick one non-empty caption (of the 5 available) per sample
    and return the CLIP token tensor for the whole batch."""
    picked = []
    for sample in annots:
        caption = ''
        while caption == '':
            choice = torch.randint(5, (1, 1))[0][0]
            caption = sample[0, choice]
        picked.append(caption)
    return clip.tokenize(np.array(picked).flatten())

def clip_text_embedder(text_token):
    """CLIP text encoder (no grad)."""
    with torch.no_grad():
        return clip_model.encode_text(text_token.to(device))

def clip_image_embedder(image):
    """CLIP ViT image encoder, clamped to [-1.5, 1.5] (no grad)."""
    with torch.no_grad():
        feats = clip_model.encode_image(image.to(device))
        feats = torch.clamp(feats, -1.5, 1.5)
    return feats

def resnet_image_embedder(image):
    """CLIP RN50 image encoder (no grad)."""
    with torch.no_grad():
        return resnet_model.encode_image(image.to(device))

# ---------------------------------------------------------------------------
# Current cell: only the ViT image embedder is used (redefines the above).
# ---------------------------------------------------------------------------
clip_model, _ = clip.load("ViT-L/14", device=device)
# resnet_model, _ = clip.load("RN50", device=device)
clip_model.eval()
# resnet_model.eval()

def clip_image_embedder(image):
    """Delegate to the training-time `embedder`; only valid when the notebook
    was configured for the clip_image_vit model."""
    assert model_name == 'clip_image_vit', model_name
    return embedder(image)

# ---------------------------------------------------------------------------
# Brain networks: voxel -> embedding predictors loaded from checkpoints.
# Earlier version loaded three networks:
#   brain_net = BrainNetwork(768)
#   brain_net_clip_img = brain_net.to(device)   (epoch-20 clip_image_vit ckpt)
#   brain_net_clip_text = brain_net.to(device)  (epoch-20 clip_text_vit ckpt)
#   brain_net = BrainNetwork(1024)
#   brain_net_resnet_img = brain_net.to(device) (epoch-42 clip_image_resnet ckpt)
# NOTE(review): in that version brain_net_clip_text aliased the SAME module
# object as brain_net_clip_img (`.to(device)` returns the same module), so the
# text weights clobbered the image weights — verify before reviving it.
# ---------------------------------------------------------------------------
brain_net = BrainNetwork(768)
brain_net_clip_img = brain_net.to(device)
# checkpoint = torch.load('checkpoints/clip_image_vit_subj01_epoch20.pth', map_location=device)
checkpoint = torch.load(ckpt_path, map_location=device)
brain_net_clip_img.load_state_dict(checkpoint['model_state_dict'])
brain_net_clip_img.eval()
# ---------------------------------------------------------------------------
# Top-K retrieval evaluation (earlier three-model version): embed the batch
# of ground-truth images/captions, predict embeddings from voxels with the
# brain networks, then rank images by batchwise cosine similarity.
# ---------------------------------------------------------------------------
with torch.cuda.amp.autocast():
    voxel = voxel.to(device)
    embt = text_tokenize(subj01_annots[trial_idx]).to(device)
    emb0 = []
    emb1 = []
    emb2 = []
    # Embed the targets in minibatches to fit on the GPU; accumulate on CPU.
    for m in np.arange(0, batch_size, minibatch):
        img_feats = clip_image_embedder(emb[m:m + minibatch]).detach().cpu()
        res_feats = resnet_image_embedder(emb[m:m + minibatch]).detach().cpu()
        txt_feats = clip_text_embedder(embt[m:m + minibatch]).detach().cpu()
        if m == 0:
            emb0, emb1, emb2 = img_feats, res_feats, txt_feats
        else:
            emb0 = torch.vstack((emb0, img_feats))
            emb1 = torch.vstack((emb1, res_feats))
            emb2 = torch.vstack((emb2, txt_feats))
    emb0 = emb0.to(device)
    emb1 = emb1.to(device)
    emb2 = emb2.to(device)

    # Voxel -> embedding predictions from the three brain networks.
    emb_0 = brain_net_clip_img(voxel)
    emb_1 = brain_net_resnet_img(voxel)
    emb_2 = brain_net_clip_text(voxel)

    labels = torch.arange(len(emb0)).to(device)
    similarities0 = batchwise_cosine_similarity(emb0, emb_0)
    similarities1 = batchwise_cosine_similarity(emb1, emb_1)
    similarities2 = batchwise_cosine_similarity(emb2, emb_2)
    # how to combine the different models?
    similaritiesx = similarities0 / 2 + similarities1 + similarities2 / 2

    def _show_retrievals(similarities):
        """For the first 4 trials, plot the original image next to its
        top-5 retrieved images (ranked by descending similarity)."""
        sims = np.array(similarities.detach().cpu())
        for trial in range(4):
            fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(11, 6))
            ax[0].imshow(torch_to_Image(emb[trial]))
            ax[0].set_title("original\nimage")
            ax[0].axis("off")
            for attempt in range(5):
                which = np.flip(np.argsort(sims[trial]))[attempt]
                ax[attempt + 1].imshow(torch_to_Image(emb[which]))
                ax[attempt + 1].set_title(f"Top {attempt}")
                ax[attempt + 1].axis("off")
            plt.show()

    print("CLIP IMG")
    plt.show()

    print("\nRESNET50 IMG")
    print("percent_correct", topk(similarities1, labels, k=1))
    _show_retrievals(similarities1)

    print("\nCLIP TEXT")
    print("percent_correct", topk(similarities2, labels, k=1))
    _show_retrievals(similarities2)

    print("\nCOMBINED")
    print("percent_correct", topk(similaritiesx, labels, k=1))
    _show_retrievals(similaritiesx)

# ---------------------------------------------------------------------------
# Current cell: CLIP image embeddings only (ResNet/text paths disabled).
# ---------------------------------------------------------------------------
with torch.cuda.amp.autocast():
    voxel = voxel.to(device)
    # embt = text_tokenize(subj01_annots[trial_idx]).to(device)
    emb0 = []
    emb1 = []
    emb2 = []
    for m in np.arange(0, batch_size, minibatch):
        feats = clip_image_embedder(emb[m:m + minibatch]).detach().cpu()
        emb0 = feats if m == 0 else torch.vstack((emb0, feats))
    emb0 = emb0.to(device)

    emb_0 = brain_net_clip_img(voxel)
    emb_0 = nn.functional.normalize(emb_0, dim=-1)  # <TODO> move into network

    labels = torch.arange(len(emb0)).to(device)
    similarities0 = batchwise_cosine_similarity(emb0, emb_0)

    print("CLIP IMG")
    plt.show()
    # The RESNET50 IMG / CLIP TEXT / COMBINED reporting sections from the
    # earlier cell above were disabled here along with their embedders.